import tensorflow as tf


def positional_embedding(pos_seq, inv_freq, bsz = None):
    """Build sinusoidal position embeddings.

    Args:
      pos_seq: 1-D tensor of positions.
      inv_freq: 1-D tensor of inverse frequencies.
      bsz: optional batch size. When given, the embedding is tiled along the
        new middle axis `bsz` times; otherwise that axis has size 1.

    Returns:
      A tensor of shape [len(pos_seq), 1 or bsz, 2 * len(inv_freq)] holding
      the sine features followed by the cosine features.
    """
    angles = tf.einsum('i,j->ij', pos_seq, inv_freq)
    features = tf.concat([tf.sin(angles), tf.cos(angles)], -1)
    features = features[:, None, :]
    if bsz is None:
        return features
    return tf.tile(features, [1, bsz, 1])


def positionwise_FF(inp, d_model, d_inner, kernel_initializer, scope = 'ff'):
    """Position-wise feed-forward sub-layer with a residual and layer norm.

    Computes layer_norm(dense_2(relu(dense_1(inp))) + inp), normalizing over
    the last axis. The residual addition requires the last dimension of
    `inp` to equal `d_model`.

    Args:
      inp: input activations (last axis d_model).
      d_model: output width of the second dense layer.
      d_inner: hidden width of the first dense layer.
      kernel_initializer: initializer for both dense kernels.
      scope: variable scope name.
    """
    with tf.variable_scope(scope):
        hidden = tf.layers.dense(
            inp,
            d_inner,
            activation = tf.nn.relu,
            kernel_initializer = kernel_initializer,
            name = 'layer_1',
        )
        projected = tf.layers.dense(
            hidden,
            d_model,
            kernel_initializer = kernel_initializer,
            name = 'layer_2',
        )
        residual = projected + inp
        return tf.contrib.layers.layer_norm(residual, begin_norm_axis = -1)


def rel_shift(x):
    """Shift relative-position scores so they line up with key positions.

    Uses the "pad, reshape, drop first row, reshape back" trick on a 4-D
    tensor whose first two axes are [query, relative-position]. After the
    op, row i of the result equals row i of the input shifted left by
    (qlen - 1 - i) positions along axis 1. Trailing entries of each row are
    filler (zeros or values from the next row). They are expected to be
    hidden by the causal attention mask.

    In this file it is applied to BD in 'ijbn' layout
    ([qlen, rlen, bsz, n_head]). There `r` is built with positions ordered
    from klen - 1 down to 0, so the result is indexed by
    [query, absolute key position].
    """
    x_size = tf.shape(x)

    # Prepend one zero column along axis 1: [q, k, ...] -> [q, k + 1, ...].
    x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
    # Reinterpret the same buffer with the first two axes swapped in size.
    x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]])
    # Dropping the first "row" removes q elements, which produces the
    # per-row shift once the buffer is viewed as [q, k, ...] again.
    x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
    x = tf.reshape(x, x_size)

    return x


def rel_multihead_attn(
    w,
    r,
    r_w_bias,
    r_r_bias,
    attn_mask,
    mems,
    d_model,
    n_head,
    d_head,
    kernel_initializer,
    scope = 'rel_attn',
):
    """Relative-position multi-head self-attention with residual + layer norm.

    Args:
      w: current-segment hidden states. Axis 0 is read as qlen and axis 1 as
        bsz, and the output is added back to `w`, so presumably
        [qlen, bsz, d_model] — TODO confirm at call sites.
      r: relative position embeddings. Axis 0 is read as rlen. They are
        projected to n_head * d_head and reshaped to [rlen, n_head, d_head].
      r_w_bias: bias added to the queries for the content term (AC). It must
        broadcast against [qlen, bsz, n_head, d_head].
      r_r_bias: bias added to the queries for the position term (BD). It has
        the same broadcasting requirement.
      attn_mask: 2-D float mask indexed as [query, key]. A 1 marks a
        disallowed position (its score is pushed to about -1e30); a 0
        allows it.
      mems: cached states, prepended to `w` along axis 0 when not None and
        of rank > 1. Otherwise they are ignored.
      d_model: output width of the final 'o' projection.
      n_head: number of attention heads.
      d_head: per-head width.
      kernel_initializer: initializer for all dense kernels.
      scope: variable scope name.

    Returns:
      layer_norm(o(attention_vectors) + w), normalized over the last axis.
    """
    scale = 1 / (d_head ** 0.5)
    with tf.variable_scope(scope):
        qlen = tf.shape(w)[0]
        rlen = tf.shape(r)[0]
        bsz = tf.shape(w)[1]

        # Keys/values are computed over [mems; w] so queries can attend to
        # the cached previous segment as well.
        cat = (
            tf.concat([mems, w], 0)
            if mems is not None and mems.shape.ndims > 1
            else w
        )
        w_heads = tf.layers.dense(
            cat,
            3 * n_head * d_head,
            use_bias = False,
            kernel_initializer = kernel_initializer,
            name = 'qkv',
        )
        r_head_k = tf.layers.dense(
            r,
            n_head * d_head,
            use_bias = False,
            kernel_initializer = kernel_initializer,
            name = 'r',
        )

        w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, -1)
        # Only the last qlen rows (the current segment) act as queries.
        w_head_q = w_head_q[-qlen:]

        klen = tf.shape(w_head_k)[0]

        w_head_q = tf.reshape(w_head_q, [qlen, bsz, n_head, d_head])
        w_head_k = tf.reshape(w_head_k, [klen, bsz, n_head, d_head])
        w_head_v = tf.reshape(w_head_v, [klen, bsz, n_head, d_head])

        r_head_k = tf.reshape(r_head_k, [rlen, n_head, d_head])

        rw_head_q = w_head_q + r_w_bias
        rr_head_q = w_head_q + r_r_bias

        # AC: content-based scores; BD: position-based scores, re-aligned to
        # key positions by rel_shift. Both are [qlen, klen/rlen, bsz, n_head].
        AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
        BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)
        BD = rel_shift(BD)

        attn_score = (AC + BD) * scale
        # Masked entries (mask == 1) become about -1e30, so softmax gives
        # them roughly zero weight.
        attn_mask_t = attn_mask[:, :, None, None]
        attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t

        # Normalize over axis 1, the key axis 'j'.
        attn_prob = tf.nn.softmax(attn_score, 1)
        attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
        size_t = tf.shape(attn_vec)
        # Merge heads: [qlen, bsz, n_head, d_head] -> [qlen, bsz, n_head*d_head].
        attn_vec = tf.reshape(attn_vec, [size_t[0], size_t[1], n_head * d_head])

        attn_out = tf.layers.dense(
            attn_vec,
            d_model,
            use_bias = False,
            kernel_initializer = kernel_initializer,
            name = 'o',
        )

        output = tf.contrib.layers.layer_norm(
            attn_out + w, begin_norm_axis = -1
        )
    return output


def embedding_lookup(lookup_table, x):
    """Gather the rows of `lookup_table` selected by the integer ids in `x`.

    Thin wrapper over `tf.nn.embedding_lookup`. Every embedding gather in
    this module goes through it.
    """
    gathered = tf.nn.embedding_lookup(lookup_table, x)
    return gathered


def mask_adaptive_embedding_lookup(
    x,
    n_token,
    d_embed,
    d_proj,
    cutoffs,
    initializer,
    proj_initializer,
    div_val = 1,
    proj_same_dim = True,
    scope = 'adaptive_embed',
    **kwargs
):
    """Adaptive input embedding that routes tokens to bins with boolean masks.

    Args:
      x: integer token ids. The div_val != 1 path builds its output from
        tf.shape(x)[0] and tf.shape(x)[1], and the div_val == 1 path projects
        with an 'ibe' einsum, so `x` is expected to be 2-D (presumably
        [len, bsz]).
      n_token: vocabulary size.
      d_embed: embedding width of the first bin.
      d_proj: output width. The output is also scaled by sqrt(d_proj).
      cutoffs: increasing vocabulary cutoffs that split the vocabulary into
        bins. Any sequence (list or tuple) is accepted. It is ignored when
        div_val == 1.
      initializer: initializer for the lookup tables.
      proj_initializer: initializer for the projection matrices.
      div_val: bin i uses embedding width d_embed // div_val**i. A value of 1
        means a single table of width d_embed.
      proj_same_dim: when False, bins whose width already equals d_proj get
        no projection.
      scope: variable scope name.
      **kwargs: ignored.

    Returns:
      (y, ret_params). y is the embedded input times sqrt(d_proj), with last
      dimension d_proj. ret_params is [lookup_table, proj_W] when
      div_val == 1 (proj_W is None if d_proj == d_embed). Otherwise it is
      [tables, projs], with one entry per bin (None where no projection).
    """
    emb_scale = d_proj ** 0.5
    with tf.variable_scope(scope):
        if div_val == 1:
            lookup_table = tf.get_variable(
                'lookup_table', [n_token, d_embed], initializer = initializer
            )
            y = embedding_lookup(lookup_table, x)
            if d_proj != d_embed:
                proj_W = tf.get_variable(
                    'proj_W', [d_embed, d_proj], initializer = proj_initializer
                )
                y = tf.einsum('ibe,ed->ibd', y, proj_W)
            else:
                proj_W = None
            ret_params = [lookup_table, proj_W]
        else:
            tables, projs = [], []
            # list() so that a tuple (or any other sequence) of cutoffs works;
            # `[0] + tuple` would raise TypeError.
            cutoff_ends = [0] + list(cutoffs) + [n_token]
            x_size = tf.shape(x)
            # Each bin scatters its embeddings into this dense output.
            y = tf.zeros([x_size[0], x_size[1], d_proj])
            for i in range(len(cutoff_ends) - 1):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
                    mask = (x >= l_idx) & (x < r_idx)
                    # Ids local to this bin's table.
                    cur_x = tf.boolean_mask(x, mask) - l_idx
                    cur_d_embed = d_embed // (div_val ** i)
                    lookup_table = tf.get_variable(
                        'lookup_table',
                        [r_idx - l_idx, cur_d_embed],
                        initializer = initializer,
                    )
                    cur_y = embedding_lookup(lookup_table, cur_x)
                    if d_proj == cur_d_embed and not proj_same_dim:
                        proj_W = None
                    else:
                        proj_W = tf.get_variable(
                            'proj_W',
                            [cur_d_embed, d_proj],
                            initializer = proj_initializer,
                        )
                        cur_y = tf.einsum('id,de->ie', cur_y, proj_W)
                    mask_idx = tf.to_int64(tf.where(mask))
                    y += tf.scatter_nd(
                        mask_idx, cur_y, tf.to_int64(tf.shape(y))
                    )
                    tables.append(lookup_table)
                    projs.append(proj_W)
            ret_params = [tables, projs]

    y *= emb_scale
    return y, ret_params


def mul_adaptive_embedding_lookup(
    x,
    n_token,
    d_embed,
    d_proj,
    cutoffs,
    initializer,
    proj_initializer,
    div_val = 1,
    perms = None,
    proj_same_dim = True,
    scope = 'adaptive_embed',
):
    """Adaptive input embedding that avoids boolean masks.

    Args:
      x: integer token ids. The perms path builds its output from
        tf.shape(x)[0] and tf.shape(x)[1], and the div_val == 1 path uses an
        'ibe' einsum, so `x` is expected to be 2-D (presumably [len, bsz]).
      n_token: vocabulary size.
      d_embed: embedding width of the first bin.
      d_proj: output width. The output is also scaled by sqrt(d_proj).
      cutoffs: increasing vocabulary cutoffs. Any sequence (list or tuple) is
        accepted. It is ignored when div_val == 1.
      initializer: initializer for the lookup tables.
      proj_initializer: initializer for the projection matrices.
      div_val: bin i uses embedding width d_embed // div_val**i.
      perms: if None, first compute W = W1 x W2 (the projected table of each
        bin), concatenate them, and then compute X x W with a single
        embedding lookup. If not None, use bin-based embedding lookup:
        perms[0] is a 2-D weight applied to the first-bin lookup (via
        perms[0][:, :, None]). perms[i] for i > 0 is a 3-D 'ibk' selector
        whose last axis (max_bin_size) is the number of slots for bin i.
      proj_same_dim: when False, bins whose width already equals d_proj get
        no projection.
      scope: variable scope name.

    Returns:
      (y, ret_params). y is the embedded input times sqrt(d_proj).
      ret_params is [lookup_table, proj_W] when div_val == 1, otherwise
      [tables, projs] with one entry per bin (None where no projection).
    """
    emb_scale = d_proj ** 0.5
    with tf.variable_scope(scope):
        if div_val == 1:
            lookup_table = tf.get_variable(
                'lookup_table', [n_token, d_embed], initializer = initializer
            )
            y = embedding_lookup(lookup_table, x)
            if d_proj != d_embed:
                proj_W = tf.get_variable(
                    'proj_W', [d_embed, d_proj], initializer = proj_initializer
                )
                y = tf.einsum('ibe,ed->ibd', y, proj_W)
            else:
                proj_W = None
            ret_params = [lookup_table, proj_W]
        else:
            tables, projs = [], []
            # list() so that a tuple (or any other sequence) of cutoffs works.
            cutoff_ends = [0] + list(cutoffs) + [n_token]
            x_size = tf.shape(x)
            if perms is None:
                cat_lookup = []
            else:
                cat_lookup = tf.zeros([x_size[0], x_size[1], d_proj])
            for i in range(len(cutoff_ends) - 1):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
                    cur_d_embed = d_embed // (div_val ** i)
                    lookup_table = tf.get_variable(
                        'lookup_table',
                        [r_idx - l_idx, cur_d_embed],
                        initializer = initializer,
                    )
                    if cur_d_embed == d_proj and not proj_same_dim:
                        proj_W = None
                    else:
                        proj_W = tf.get_variable(
                            'proj_W',
                            [cur_d_embed, d_proj],
                            initializer = proj_initializer,
                        )
                    if perms is None:
                        if proj_W is None:
                            # No projection: the table width already equals
                            # d_proj. Passing None to einsum would crash.
                            cat_lookup.append(lookup_table)
                        else:
                            cat_lookup.append(
                                tf.einsum('ie,ed->id', lookup_table, proj_W)
                            )
                    else:
                        # Speed up the computation of the first bin and save
                        # some memory.
                        if i == 0:
                            # Clamp ids into the first bin's range so the
                            # gather is always valid. perms[0] then weights
                            # the result (presumably 0 outside this bin).
                            cur_y = embedding_lookup(
                                lookup_table, tf.minimum(x, r_idx - 1)
                            )
                            if proj_W is not None:
                                cur_y = tf.einsum('ibe,ed->ibd', cur_y, proj_W)
                            cur_y *= perms[i][:, :, None]
                            cat_lookup += cur_y
                        else:
                            # Gather this bin's local ids into its k slots.
                            cur_x = tf.einsum(
                                'ib,ibk->k', tf.to_float(x - l_idx), perms[i]
                            )
                            cur_x = tf.to_int32(cur_x)
                            cur_y = embedding_lookup(lookup_table, cur_x)
                            if proj_W is not None:
                                cur_y = tf.einsum('ke,ed->kd', cur_y, proj_W)
                            # Scatter the slot embeddings back to positions.
                            cat_lookup += tf.einsum(
                                'kd,ibk->ibd', cur_y, perms[i]
                            )
                    tables.append(lookup_table)
                    projs.append(proj_W)
            if perms is None:
                cat_lookup = tf.concat(cat_lookup, 0)
                y = embedding_lookup(cat_lookup, x)
            else:
                y = cat_lookup
            ret_params = [tables, projs]

    y *= emb_scale
    return y, ret_params


def mask_adaptive_logsoftmax(
    hidden,
    target,
    n_token,
    d_embed,
    d_proj,
    cutoffs,
    params,
    tie_projs,
    initializer = None,
    proj_initializer = None,
    div_val = 1,
    scope = 'adaptive_softmax',
    proj_same_dim = True,
    return_mean = True,
    **kwargs
):
    """Adaptive softmax negative log-likelihood using boolean masks per bin.

    The head (bin 0) scores the first-bin tokens plus one logit per tail
    cluster. A token in tail bin i scores
    log p(cluster i | head) + log p(token | tail bin i).

    Args:
      hidden: hidden states. `_logit` contracts them with an 'ibd' einsum,
        so they are 3-D with last axis d_proj.
      target: integer target ids matching the leading axes of `hidden`. The
        dtype may be int32 or int64.
      n_token: vocabulary size.
      d_embed: width of the first bin (bin i has d_embed // div_val**i).
      d_proj: hidden-state width.
      cutoffs: increasing vocabulary cutoffs. Any sequence is accepted. An
        empty sequence means a plain (non-adaptive) softmax.
      params: [W, projs], presumably the ret_params of the embedding lookup
        (tied input/output weights). With div_val == 1, W is one table that
        is sliced per bin. Otherwise W and projs are indexed per bin.
      tie_projs: per-bin flags. True reuses params' projection for that bin.
        Entries missing from a shorter sequence (or None) are treated as
        False (untied).
      initializer: unused.
      proj_initializer: initializer for untied projection matrices.
      div_val: embedding width divisor between consecutive bins.
      scope: variable scope name.
      proj_same_dim: when False, untied bins whose width equals d_proj get
        no projection.
      return_mean: when True, average the per-position NLL to a scalar.
      **kwargs: ignored.

    Returns:
      The per-position NLL shaped like `target`, or its mean if return_mean.
    """
    cutoffs = list(cutoffs)
    tie_flags = [] if tie_projs is None else list(tie_projs)

    def _is_tied(i):
        # Missing trailing entries mean "untied" instead of IndexError.
        return i < len(tie_flags) and bool(tie_flags[i])

    def _logit(x, W, b, proj):
        y = x
        if proj is not None:
            y = tf.einsum('ibd,ed->ibe', y, proj)
        return tf.einsum('ibd,nd->ibn', y, W) + b

    params_W, params_projs = params[0], params[1]

    def _gather_logprob(logprob, target):
        # Pick logprob[k, target[k]] for every row k. The targets are cast to
        # the row-index dtype so tf.stack also accepts int64 targets.
        lp_size = tf.shape(logprob)
        r = tf.range(lp_size[0])
        idx = tf.stack([r, tf.cast(target, r.dtype)], 1)
        return tf.gather_nd(logprob, idx)

    with tf.variable_scope(scope):
        if len(cutoffs) == 0:
            softmax_b = tf.get_variable(
                'bias', [n_token], initializer = tf.zeros_initializer()
            )
            output = _logit(hidden, params_W, softmax_b, params_projs)
            nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels = target, logits = output
            )
        else:
            cutoff_ends = [0] + cutoffs + [n_token]
            nll = tf.zeros_like(target, dtype = tf.float32)
            for i in range(len(cutoff_ends) - 1):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
                    mask = (target >= l_idx) & (target < r_idx)
                    mask_idx = tf.where(mask)
                    cur_target = tf.boolean_mask(target, mask) - l_idx
                    cur_d_embed = d_embed // (div_val ** i)

                    if div_val == 1:
                        cur_W = params_W[l_idx:r_idx]
                    else:
                        cur_W = params_W[i]
                    cur_b = tf.get_variable(
                        'b',
                        [r_idx - l_idx],
                        initializer = tf.zeros_initializer(),
                    )
                    if _is_tied(i):
                        if div_val == 1:
                            cur_proj = params_projs
                        else:
                            cur_proj = params_projs[i]
                    else:
                        if (
                            div_val == 1 or not proj_same_dim
                        ) and d_proj == cur_d_embed:
                            cur_proj = None
                        else:
                            cur_proj = tf.get_variable(
                                'proj',
                                [cur_d_embed, d_proj],
                                initializer = proj_initializer,
                            )
                    if i == 0:
                        # The head gets one extra logit per tail cluster.
                        cluster_W = tf.get_variable(
                            'cluster_W',
                            [len(cutoffs), d_embed],
                            initializer = tf.zeros_initializer(),
                        )
                        cluster_b = tf.get_variable(
                            'cluster_b',
                            [len(cutoffs)],
                            initializer = tf.zeros_initializer(),
                        )
                        cur_W = tf.concat([cur_W, cluster_W], 0)
                        cur_b = tf.concat([cur_b, cluster_b], 0)

                        # head_logprob is reused by the tail iterations below.
                        head_logit = _logit(hidden, cur_W, cur_b, cur_proj)
                        head_logprob = tf.nn.log_softmax(head_logit)
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
                        cur_logprob = _gather_logprob(
                            cur_head_logprob, cur_target
                        )
                    else:
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
                        cur_hidden = tf.boolean_mask(hidden, mask)
                        tail_logit = tf.squeeze(
                            _logit(cur_hidden[None], cur_W, cur_b, cur_proj), 0
                        )
                        tail_logprob = tf.nn.log_softmax(tail_logit)
                        # Cluster i's logit sits after the first-bin tokens.
                        cur_logprob = cur_head_logprob[
                            :, cutoff_ends[1] + i - 1
                        ] + _gather_logprob(tail_logprob, cur_target)
                    nll += tf.scatter_nd(
                        mask_idx, -cur_logprob, tf.to_int64(tf.shape(nll))
                    )
    if return_mean:
        nll = tf.reduce_mean(nll)
    return nll


def mul_adaptive_logsoftmax(
    hidden,
    target,
    n_token,
    d_embed,
    d_proj,
    cutoffs,
    params,
    tie_projs,
    initializer = None,
    proj_initializer = None,
    div_val = 1,
    perms = None,
    proj_same_dim = True,
    scope = 'adaptive_softmax',
    **kwargs
):
    """Adaptive softmax mean NLL that uses dense bin selectors, not masks.

    Each tail bin gathers its positions through `perms[i]` (contracted with
    'ibk' einsums), so every bin has a static number of slots k.

    Args:
      hidden: hidden states. They are 3-D ('ibd') with last axis d_proj.
      target: integer target ids, presumably [len, bsz].
      n_token: vocabulary size.
      d_embed: width of the first bin (bin i has d_embed // div_val**i).
      d_proj: hidden-state width.
      cutoffs: increasing vocabulary cutoffs. Any sequence is accepted. An
        empty sequence means a plain softmax.
      params: [W, projs], presumably the ret_params of the embedding lookup.
      tie_projs: per-bin flags. True reuses params' projection for that bin.
        Entries missing from a shorter sequence (or None) are treated as
        False (untied).
      initializer: unused.
      proj_initializer: initializer for untied projection matrices.
      div_val: embedding width divisor between consecutive bins.
      perms: required when cutoffs is non-empty. perms[0] multiplies the
        head NLL elementwise (and its sum is the head count). perms[i] for
        i > 0 is a 3-D [i, b, k] selector mapping positions to bin slots.
      proj_same_dim: when False, untied bins whose width equals d_proj get
        no projection.
      scope: variable scope name.
      **kwargs: must contain 'head_target' (labels for the head softmax)
        when cutoffs is non-empty. Other keys are ignored.

    Returns:
      Scalar NLL. Without cutoffs this is the mean over all positions.
      Otherwise it is the perms-weighted loss sum divided by the
      perms-weighted count.

    Raises:
      ValueError: if cutoffs is non-empty and `perms` is None or
        kwargs['head_target'] is missing.
    """
    cutoffs = list(cutoffs)
    tie_flags = [] if tie_projs is None else list(tie_projs)

    def _is_tied(i):
        # Missing trailing entries mean "untied" instead of IndexError.
        return i < len(tie_flags) and bool(tie_flags[i])

    def _logit(x, W, b, proj):
        y = x
        if x.shape.ndims == 3:
            if proj is not None:
                y = tf.einsum('ibd,ed->ibe', y, proj)
            return tf.einsum('ibd,nd->ibn', y, W) + b
        else:
            if proj is not None:
                y = tf.einsum('id,ed->ie', y, proj)
            return tf.einsum('id,nd->in', y, W) + b

    params_W, params_projs = params[0], params[1]

    if cutoffs:
        # Fail early with a clear message. Otherwise this would fail half-way
        # through graph construction with `None[i]` or an opaque TF error.
        if perms is None:
            raise ValueError(
                'mul_adaptive_logsoftmax requires `perms` when cutoffs '
                'are given.'
            )
        if kwargs.get('head_target') is None:
            raise ValueError(
                "mul_adaptive_logsoftmax requires the 'head_target' keyword "
                'argument when cutoffs are given.'
            )

    with tf.variable_scope(scope):
        if len(cutoffs) == 0:
            softmax_b = tf.get_variable(
                'bias', [n_token], initializer = tf.zeros_initializer()
            )
            output = _logit(hidden, params_W, softmax_b, params_projs)
            nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels = target, logits = output
            )
            nll = tf.reduce_mean(nll)
        else:
            head_target = kwargs['head_target']
            total_loss, total_cnt = 0, 0
            cutoff_ends = [0] + cutoffs + [n_token]
            for i in range(len(cutoff_ends) - 1):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]

                    cur_d_embed = d_embed // (div_val ** i)

                    if div_val == 1:
                        cur_W = params_W[l_idx:r_idx]
                    else:
                        cur_W = params_W[i]
                    cur_b = tf.get_variable(
                        'b',
                        [r_idx - l_idx],
                        initializer = tf.zeros_initializer(),
                    )
                    if _is_tied(i):
                        if div_val == 1:
                            cur_proj = params_projs
                        else:
                            cur_proj = params_projs[i]
                    else:
                        if (
                            div_val == 1 or not proj_same_dim
                        ) and d_proj == cur_d_embed:
                            cur_proj = None
                        else:
                            cur_proj = tf.get_variable(
                                'proj',
                                [cur_d_embed, d_proj],
                                initializer = proj_initializer,
                            )

                    if i == 0:
                        # The head gets one extra logit per tail cluster.
                        cluster_W = tf.get_variable(
                            'cluster_W',
                            [len(cutoffs), d_embed],
                            initializer = tf.zeros_initializer(),
                        )
                        cluster_b = tf.get_variable(
                            'cluster_b',
                            [len(cutoffs)],
                            initializer = tf.zeros_initializer(),
                        )
                        cur_W = tf.concat([cur_W, cluster_W], 0)
                        cur_b = tf.concat([cur_b, cluster_b], 0)

                        head_logit = _logit(hidden, cur_W, cur_b, cur_proj)

                        # head_nll is reused by the tail iterations below.
                        head_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels = head_target, logits = head_logit
                        )

                        masked_loss = head_nll * perms[i]
                        total_loss += tf.reduce_sum(masked_loss)
                        total_cnt += tf.reduce_sum(perms[i])
                    else:
                        # Head NLL (cluster term) gathered into this bin's slots.
                        cur_head_nll = tf.einsum(
                            'ib,ibk->k', head_nll, perms[i]
                        )

                        cur_hidden = tf.einsum('ibd,ibk->kd', hidden, perms[i])
                        tail_logit = _logit(cur_hidden, cur_W, cur_b, cur_proj)

                        # Local target ids per slot, gathered through floats.
                        tail_target = tf.einsum(
                            'ib,ibk->k', tf.to_float(target - l_idx), perms[i]
                        )
                        tail_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels = tf.to_int32(tail_target),
                            logits = tail_logit,
                        )

                        sum_nll = cur_head_nll + tail_nll
                        # Per-slot occupancy: unused slots contribute nothing.
                        mask = tf.reduce_sum(perms[i], [0, 1])

                        masked_loss = sum_nll * mask
                        total_loss += tf.reduce_sum(masked_loss)
                        total_cnt += tf.reduce_sum(mask)

            nll = total_loss / total_cnt

    return nll


def _create_mask(qlen, mlen, same_length = False):
    """Build the float attention mask for one segment (1.0 = masked).

    Returns a [qlen, mlen + qlen] tensor. The first mlen columns (memory)
    start unmasked. In the last qlen columns, entries strictly above the
    diagonal are masked. With same_length, a strictly lower-triangular
    qlen x qlen block is additionally added onto the first qlen columns, so
    row i also masks columns 0..i-1.
    """
    ones = tf.ones([qlen, qlen])
    diag = tf.matrix_band_part(ones, 0, 0)
    strict_upper = tf.matrix_band_part(ones, 0, -1) - diag
    mask = tf.concat([tf.zeros([qlen, mlen]), strict_upper], 1)
    if not same_length:
        return mask
    strict_lower = tf.matrix_band_part(ones, -1, 0) - diag
    return tf.concat([mask[:, :qlen] + strict_lower, mask[:, qlen:]], 1)


def _cache_mem(curr_out, prev_mem, mem_len = None):
    """Compute the memory to carry into the next segment.

    - mem_len is None or prev_mem is None: the memory is curr_out.
    - mem_len == 0 (with a prev_mem): prev_mem is returned unchanged,
      without stop_gradient.
    - otherwise: the last mem_len rows (axis 0) of [prev_mem; curr_out].

    In every case except mem_len == 0, gradients are stopped.
    """
    if mem_len is not None and prev_mem is not None:
        if mem_len == 0:
            return prev_mem
        combined = tf.concat([prev_mem, curr_out], 0)
        return tf.stop_gradient(combined[-mem_len:])
    return tf.stop_gradient(curr_out)


def transformer(
    dec_inp,
    mems,
    n_token,
    n_layer,
    d_model,
    d_embed,
    n_head,
    d_head,
    d_inner,
    initializer,
    proj_initializer = None,
    mem_len = None,
    cutoffs = None,
    div_val = 1,
    tie_projs = None,
    same_length = False,
    clamp_len = -1,
    untie_r = False,
    proj_same_dim = True,
    scope = 'transformer',
):
    """Transformer-XL decoder stack: adaptive embedding + n_layer blocks.

    Args:
      dec_inp: integer token ids. Axis 0 is the segment length qlen
        (presumably [qlen, bsz]).
      mems: None, or a list with one cached tensor per layer. mems[0]'s
        axis 0 gives the memory length mlen.
      n_token: vocabulary size.
      n_layer: number of attention + feed-forward blocks.
      d_model: model width.
      d_embed: embedding width of the first adaptive-embedding bin.
      n_head: number of attention heads.
      d_head: per-head width.
      d_inner: feed-forward hidden width.
      initializer: initializer for all weights (and the default for
        projections).
      proj_initializer: initializer for embedding projections. Defaults to
        `initializer`.
      mem_len: memory length passed to _cache_mem.
      cutoffs: a sequence of python ints, the cutoffs for the adaptive
        embedding. Defaults to no cutoffs.
      div_val: embedding width divisor between adaptive bins.
      tie_projs: a list of python bools (projection tying for the adaptive
        softmax). Accepted but not used by this function.
      same_length: passed to _create_mask.
      clamp_len: if > 0, relative positions are clamped to this value.
      untie_r: if True, each layer gets its own r_w_bias / r_r_bias.
      proj_same_dim: passed to the adaptive embedding lookup.
      scope: variable scope name.

    Returns:
      (output, new_mems). output is the final hidden states. new_mems is a
      list of n_layer tensors: each layer's input, cached via _cache_mem.
    """
    # None defaults avoid shared mutable default lists.
    cutoffs = [] if cutoffs is None else list(cutoffs)

    new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable(
                'r_w_bias', [n_layer, n_head, d_head], initializer = initializer
            )
            r_r_bias = tf.get_variable(
                'r_r_bias', [n_layer, n_head, d_head], initializer = initializer
            )
        else:
            r_w_bias = tf.get_variable(
                'r_w_bias', [n_head, d_head], initializer = initializer
            )
            r_r_bias = tf.get_variable(
                'r_r_bias', [n_head, d_head], initializer = initializer
            )

        qlen = tf.shape(dec_inp)[0]
        mlen = tf.shape(mems[0])[0] if mems is not None else 0
        klen = mlen + qlen

        if proj_initializer is None:
            proj_initializer = initializer
        lookup_fn = mask_adaptive_embedding_lookup
        embeddings, shared_params = lookup_fn(
            x = dec_inp,
            n_token = n_token,
            d_embed = d_embed,
            d_proj = d_model,
            cutoffs = cutoffs,
            initializer = initializer,
            proj_initializer = proj_initializer,
            div_val = div_val,
            proj_same_dim = proj_same_dim,
        )

        attn_mask = _create_mask(qlen, mlen, same_length)

        # Relative distances from klen - 1 down to 0; rel_shift relies on
        # this descending order.
        pos_seq = tf.range(klen - 1, -1, -1.0)
        if clamp_len > 0:
            pos_seq = tf.minimum(pos_seq, clamp_len)
        inv_freq = 1 / (10000 ** (tf.range(0, d_model, 2.0) / d_model))
        pos_emb = positional_embedding(pos_seq, inv_freq)

        if mems is None:
            mems = [None] * n_layer
        output = embeddings
        for i in range(n_layer):
            # cache new mems (the input of layer i)
            new_mems.append(_cache_mem(output, mems[i], mem_len))

            with tf.variable_scope('layer_{}'.format(i)):
                output = rel_multihead_attn(
                    w = output,
                    r = pos_emb,
                    r_w_bias = r_w_bias if not untie_r else r_w_bias[i],
                    r_r_bias = r_r_bias if not untie_r else r_r_bias[i],
                    attn_mask = attn_mask,
                    mems = mems[i],
                    d_model = d_model,
                    n_head = n_head,
                    d_head = d_head,
                    kernel_initializer = initializer,
                )
                output = positionwise_FF(
                    inp = output,
                    d_model = d_model,
                    d_inner = d_inner,
                    kernel_initializer = initializer,
                )

        return output, new_mems
